Understanding Multi-Head Attention and Implementing It in PyTorch

1. Basic Version

SelfAttention

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size

        # Three linear layers that produce Q, K, V
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x shape: (batch_size, seq_len, embed_size)
        batch_size, seq_len, _ = x.size()

        # Q, K, V all have shape (batch_size, seq_len, embed_size)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Attention scores: (batch_size, seq_len, seq_len)
        scores = torch.bmm(Q, K.transpose(1, 2)) / (self.embed_size ** 0.5)

        # Softmax over the last dimension gives the attention weights
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of V: (batch_size, seq_len, embed_size)
        output = torch.bmm(attention_weights, V)

        return output, attention_weights

# Example usage
if __name__ == "__main__":
    # Hyperparameters
    embed_size = 4  # embedding dimension
    seq_len = 3     # sequence length
    batch_size = 1  # batch size

    # Create the self-attention module
    sa = SelfAttention(embed_size)

    # Random input simulating one batch
    x = torch.rand(batch_size, seq_len, embed_size)

    # Forward pass
    output, attention_weights = sa(x)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
    print("Attention weight shape:", attention_weights.shape)
    print("\nAttention weight matrix:")
    print(attention_weights.squeeze().detach())
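
For reference, the code above is a single-head instance of the scaled dot-product attention from "Attention Is All You Need"; here the scaling factor d_k equals embed_size, because the projections do not split the embedding into heads:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$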

Multi-Head Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size=512, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        # embed_size must be divisible by num_heads
        assert self.head_dim * num_heads == embed_size, "Embed size needs to be divisible by num_heads"

        # Four linear layers (Q, K, V and the final output projection)
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Q, K, V [shape: (batch_size, seq_len, embed_size)]
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Split into heads: the embed_size dimension becomes num_heads x head_dim
        # New shape: (batch_size, seq_len, num_heads, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Reorder dimensions for the matrix multiplication
        # New shape: (batch_size, num_heads, seq_len, head_dim)
        Q = Q.permute(0, 2, 1, 3)
        K = K.permute(0, 2, 1, 3)
        V = V.permute(0, 2, 1, 3)

        # Scaled dot-product attention scores
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)
        # scores shape: (batch_size, num_heads, seq_len, seq_len)

        # Softmax gives the attention weights
        attention_weights = F.softmax(scores, dim=-1)

        # Apply the attention weights to V
        weighted = torch.matmul(attention_weights, V)
        # weighted shape: (batch_size, num_heads, seq_len, head_dim)

        # Merge the heads
        # 1. Reorder dimensions to (batch_size, seq_len, num_heads, head_dim)
        weighted = weighted.permute(0, 2, 1, 3)
        # 2. Merge the last two dimensions into (batch_size, seq_len, embed_size)
        weighted = weighted.contiguous().view(batch_size, seq_len, self.embed_size)

        # Final linear projection
        output = self.fc_out(weighted)

        return output, attention_weights

# Example usage
if __name__ == "__main__":
    # Settings
    embed_size = 8   # total embedding dimension
    num_heads = 2    # number of attention heads
    seq_len = 4      # sequence length
    batch_size = 1   # batch size

    # Create the multi-head attention module
    mha = MultiHeadAttention(embed_size, num_heads)

    # Random input simulating one batch
    x = torch.rand(batch_size, seq_len, embed_size)

    # Forward pass
    output, attention_weights = mha(x)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
    print("Attention weight shape:", attention_weights.shape)
    print("\nAttention matrix of the first head:")
    print(attention_weights[0, 0].detach().numpy().round(3))
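
Conceptually, the module above computes the multi-head attention of the original Transformer paper: the per-head projection matrices W_i^Q, W_i^K, W_i^V are packed into the single nn.Linear layers and recovered by the view/permute split, and fc_out plays the role of W^O:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\, KW_i^K,\, VW_i^V)$$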

2. Masked Version

SelfAttention With Mask

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size

        # Three linear layers that produce Q, K, V
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """
        Args:
            x: input tensor (batch_size, seq_len, embed_size)
            mask: optional mask where 0 marks positions to block. Two types are supported:
                - Padding mask: (batch_size, seq_len)
                - Sequence mask: (batch_size, seq_len, seq_len)
        """
        # x shape: (batch_size, seq_len, embed_size)
        batch_size, seq_len, _ = x.size()

        # Q, K, V all have shape (batch_size, seq_len, embed_size)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Attention scores: (batch_size, seq_len, seq_len)
        scores = torch.bmm(Q, K.transpose(1, 2)) / (self.embed_size ** 0.5)

        # Apply the mask, if given
        if mask is not None:
            """
            Masking logic:
            - Positions to be blocked get a very large negative score (-1e9)
            - The mask shape must be broadcastable to the shape of scores
            """
            # A 2D padding mask is expanded so it broadcasts over the query dimension;
            # a 3D sequence mask already matches the shape of scores and is used as-is.
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)  # (batch_size, 1, seq_len)

            # Set blocked positions (mask == 0) to -1e9
            scores = scores.masked_fill(mask == 0, -1e9)

        # Softmax over the last dimension gives the attention weights
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of V: (batch_size, seq_len, embed_size)
        output = torch.bmm(attention_weights, V)

        return output, attention_weights

# Example usage
if __name__ == "__main__":
    # Hyperparameters
    embed_size = 4  # embedding dimension
    seq_len = 3     # sequence length
    batch_size = 1  # batch size

    # Create the self-attention module
    sa = SelfAttention(embed_size)

    # Test 1: basic usage without a mask
    x = torch.rand(batch_size, seq_len, embed_size)
    output, attn = sa(x)
    print("Attention weights without a mask:")
    print(attn.squeeze().detach().numpy().round(3))

    # Test 2: padding mask
    # Assume the second position is padding (0 = blocked)
    padding_mask = torch.tensor([[1, 0, 1]], dtype=torch.float32)  # (batch_size, seq_len)
    _, attn_pad = sa(x, mask=padding_mask)
    print("\nAttention weights with a padding mask:")
    print(attn_pad.squeeze().detach().numpy().round(3))

    # Test 3: sequence (causal) mask, as used in a decoder
    # Lower triangle = 1 (attend), upper triangle = 0 (blocked future positions)
    sequence_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    sequence_mask = sequence_mask.unsqueeze(0)  # (1, seq_len, seq_len)
    _, attn_seq = sa(x, mask=sequence_mask)
    print("\nAttention weights with a sequence mask:")
    print(attn_seq.squeeze().detach().numpy().round(3))
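
In practice a decoder often needs both masks at once. Under the convention above (0 = blocked), a padding mask and a causal mask can be multiplied into a single (batch_size, seq_len, seq_len) mask. A minimal sketch, assuming the SelfAttention class defined above is in scope (the tensor values are made up for illustration):

# Combine a padding mask and a causal mask into one 3D mask
import torch

sa = SelfAttention(embed_size=4)
x = torch.rand(1, 3, 4)

padding_mask = torch.tensor([[1, 0, 1]])            # (1, 3): the second position is padding
causal_mask = torch.tril(torch.ones(3, 3)).long()   # (3, 3): 1 on and below the diagonal
combined = padding_mask.unsqueeze(1) * causal_mask  # broadcasts to (1, 3, 3); 0 = blocked

_, attn = sa(x, mask=combined)
print(attn.squeeze().detach().numpy().round(3))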

Multi-Head Attention With Mask

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size=512, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert self.head_dim * num_heads == embed_size, "Embed size must be divisible by num_heads"

        # Four linear layers
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """
        Args:
            x: input tensor (batch_size, seq_len, embed_size)
            mask: optional mask where 0 marks positions to block. Two types are supported:
                - Padding mask: (batch_size, seq_len)
                - Sequence mask: (batch_size, seq_len, seq_len)
        """
        batch_size, seq_len, _ = x.shape

        # Q, K, V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Split into heads: [batch_size, num_heads, seq_len, head_dim]
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)

        # Apply the mask
        if mask is not None:
            """
            Mask handling rules:
            1. Mask inputs of different ranks are adapted automatically
            2. The mask is broadcast across all attention heads
            3. Positions where the mask is 0 are blocked
            """
            # Adjust the mask rank
            if mask.dim() == 2:    # (batch_size, seq_len) → padding mask
                mask = mask.unsqueeze(1).unsqueeze(1)  # [batch_size, 1, 1, seq_len]
            elif mask.dim() == 3:  # (batch_size, seq_len, seq_len) → sequence mask
                mask = mask.unsqueeze(1)               # [batch_size, 1, seq_len, seq_len]

            # Set blocked positions (mask == 0) to -1e9; broadcasting covers all heads
            scores = scores.masked_fill(mask == 0, -1e9)

        # Attention weights
        attention_weights = F.softmax(scores, dim=-1)

        # Context vectors
        context = torch.matmul(attention_weights, V)

        # Merge heads: [batch_size, seq_len, embed_size]
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(batch_size, seq_len, self.embed_size)

        # Final linear projection
        output = self.fc_out(context)

        return output, attention_weights

# Example usage
if __name__ == "__main__":
    # Settings
    embed_size = 8
    num_heads = 2
    seq_len = 4
    batch_size = 1

    # Initialize the module
    mha = MultiHeadAttention(embed_size, num_heads)

    # Test input
    x = torch.rand(batch_size, seq_len, embed_size)

    print("===== Test 1: no mask =====")
    out, attn = mha(x)
    print(f"Attention weight shape: {attn.shape}")
    print("Attention matrix of the first head:")
    print(attn[0, 0].detach().numpy().round(3))

    print("\n===== Test 2: padding mask =====")
    padding_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool)  # the last two positions are padding (0 = blocked)
    _, attn_pad = mha(x, mask=padding_mask)
    print("Attention matrix with a padding mask:")
    print(attn_pad[0, 0].detach().numpy().round(3))

    print("\n===== Test 3: sequence (causal) mask =====")
    seq_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()  # 1 = attend, 0 = blocked future position
    seq_mask = seq_mask.unsqueeze(0)  # add the batch dimension
    _, attn_seq = mha(x, mask=seq_mask)
    print("Attention matrix with a sequence mask:")
    print(attn_seq[0, 0].detach().numpy().round(3))
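
For comparison only (not part of the implementation above): since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention fuses the score, softmax, and weighted-sum steps into a single call on tensors shaped (batch_size, num_heads, seq_len, head_dim), with built-in causal masking. A minimal sketch:

import torch
import torch.nn.functional as F

batch_size, num_heads, seq_len, head_dim = 1, 2, 4, 4
q = torch.rand(batch_size, num_heads, seq_len, head_dim)
k = torch.rand(batch_size, num_heads, seq_len, head_dim)
v = torch.rand(batch_size, num_heads, seq_len, head_dim)

# is_causal=True applies the causal mask internally, so no explicit mask tensor is needed
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 4, 4])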

3. Unified Implementation

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size=512, num_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        assert self.head_dim * num_heads == embed_size, "Embed size must be divisible by num_heads"

        # Q, K, V projections
        self.Q = nn.Linear(embed_size, embed_size)
        self.K = nn.Linear(embed_size, embed_size)
        self.V = nn.Linear(embed_size, embed_size)
        self.scale = self.head_dim ** 0.5

        # Softmax over the last dimension
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, mask=None):
        """
        X: [batch_size, seq_len, emb_dim]
        mask (0 = blocked):
            - [batch_size, seq_len] for a padding mask
            - [batch_size, seq_len, seq_len] for a causal mask
        """
        # Single-head attention
        if self.num_heads <= 1:
            query = self.Q(X)
            key = self.K(X)
            value = self.V(X)  # [batch_size, seq_len, emb_dim]

            scores = torch.bmm(query, key.transpose(1, 2)) / self.scale  # [batch_size, seq_len, seq_len]
            if mask is not None:
                if mask.dim() == 2:
                    mask = mask.unsqueeze(dim=1)  # [batch_size, 1, seq_len]
                elif mask.dim() == 3:
                    pass
                else:
                    raise ValueError("The dim of mask must be 2 or 3!")
                scores = scores.masked_fill(mask == 0, -1e9)  # relies on broadcasting
            attention_weights = self.softmax(scores)      # [batch_size, seq_len, seq_len]
            output = torch.bmm(attention_weights, value)  # [batch_size, seq_len, emb_dim]
            return output
        # Multi-head attention
        else:
            batch_size, seq_len, _emb_dim = X.size()
            query = self.Q(X).view(batch_size, -1, self.num_heads, self.head_dim)
            key = self.K(X).view(batch_size, -1, self.num_heads, self.head_dim)
            value = self.V(X).view(batch_size, -1, self.num_heads, self.head_dim)  # [batch_size, seq_len, heads, head_dim]

            scores = torch.matmul(query.permute(0, 2, 1, 3), key.permute(0, 2, 3, 1)) / self.scale  # [batch_size, heads, seq_len, seq_len]
            if mask is not None:
                if mask.dim() == 2:
                    mask = mask.unsqueeze(dim=1).unsqueeze(dim=1)  # [batch_size, 1, 1, seq_len]
                elif mask.dim() == 3:
                    mask = mask.unsqueeze(dim=1)  # [batch_size, 1, seq_len, seq_len]
                else:
                    raise ValueError("The dim of mask must be 2 or 3!")
                scores = scores.masked_fill(mask == 0, -1e9)  # relies on broadcasting
            attention_weights = self.softmax(scores)  # [batch_size, heads, seq_len, seq_len]
            output = torch.matmul(attention_weights, value.permute(0, 2, 1, 3))  # [batch_size, heads, seq_len, head_dim]
            output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_size)
            return output
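
The unified version has no __main__ block of its own, so here is a minimal usage sketch (assumed, not from the original) that exercises both the single-head and the multi-head branch with a causal mask:

# Example usage for the unified MultiHeadAttention
if __name__ == "__main__":
    batch_size, seq_len, embed_size = 2, 5, 512
    x = torch.rand(batch_size, seq_len, embed_size)
    causal = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)  # (1, seq_len, seq_len), 1 = attend

    mha = MultiHeadAttention(embed_size=embed_size, num_heads=8)
    print(mha(x, mask=causal).shape)     # torch.Size([2, 5, 512])

    single = MultiHeadAttention(embed_size=embed_size, num_heads=1)
    print(single(x, mask=causal).shape)  # torch.Size([2, 5, 512])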

References